Copying from https://github.com/google/jax/issues/20750:
import jax
import jax.numpy as jnp

def test_func(x, y):
    return x, y

def main():
    # Print available JAX devices
    print("JAX devices:", jax.devices())

    # Create two 2x2 matrices
    a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
    b = jnp.array([[5.0, 6.0], [7.0, 8.0]])

    # Perform matrix multiplication
    c = jnp.dot(a, b)

    # Print the result
    print("Result of matrix multiplication:")
    print(c)

    # Compute the gradient of the sum of c with respect to a
    grad_a = jax.grad(lambda a: jnp.sum(jnp.dot(a, b)))(a)
    print("Gradient with respect to a:")
    print(grad_a)

    # Scan test_func over a (5, 5, 5) input; the segfault reported below
    # happens in this part of the script (the matmul and gradient results
    # above still print).
    rng = jax.random.PRNGKey(0)
    test_input = jax.random.normal(key=rng, shape=(5, 5, 5))
    initial_state = jax.numpy.array(0.0)
    x, y = jax.lax.scan(test_func, initial_state, test_input)

if __name__ == "__main__":
    main()
Running this produces:
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-04-15 18:22:28.994752: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M2 Pro
systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
JAX devices: [METAL(id=0)]
Result of matrix multiplication:
[[19. 22.]
[43. 50.]]
Gradient with respect to a:
[[11. 15.]
[11. 15.]]
zsh: segmentation fault python JAXTest.py
With more info from the debugger:
Current thread 0x00000001fdd3bac0 (most recent call first):
File "/Users/.../anaconda3/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1213 in __call__
My configuration is:
jax-metal: 0.0.6
jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.24.3
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:49:36) [Clang 16.0.6]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64')
macOS 14.4.1 (23E224)
This was not happening before with Python 3.9 and jax-metal 0.0.3.
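For reference, here is a minimal sketch (mine, not from the original issue) that isolates the jax.lax.scan call, which appears to be the failing operation since the matmul and gradient results print before the crash. The step function and variable names are hypothetical:

# Minimal sketch isolating jax.lax.scan on this setup (assumption: scan is
# the operation that triggers the METAL crash).
import jax
import jax.numpy as jnp

def step(carry, x):
    # Identity step: the carry is passed through unchanged and each (5, 5)
    # input slice is emitted as an output.
    return carry, x

def main():
    rng = jax.random.PRNGKey(0)
    xs = jax.random.normal(key=rng, shape=(5, 5, 5))
    final_carry, ys = jax.lax.scan(step, jnp.array(0.0), xs)
    print(final_carry, ys.shape)

if __name__ == "__main__":
    main()

Running the same script with the environment variable JAX_PLATFORMS=cpu (which forces the CPU backend) may help confirm whether the crash is specific to the jax-metal plugin.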